# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define CodeLine object."""
import re
from typing import List, Tuple

from mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords


class Fragment:
    """
    Fragment definition for MindSpore code generation.

    Args:
        data_entity (dict): Required data by operations. The format of `data_entity` is as follows:
            {
                "var1": {
                    "metadata": {  # ONNX Metadata
                        "operation": "Conv2d",
                        "source": "conv_pw_13/Conv2D",
                        "attributes": {
                            # Put original onnx attributes here.
                        }
                    },
                    "variable_name": None,
                    "inputs": [],
                    "output_type": "tensor" | "array",
                    "args": {"in_channels": 768, "out_channels": 1024},
                    "trainable_params": {"weight": "Parameter(Tensor(GLOBAL_W[NAME]))"}
                },
                "var2": {
                    "variable_name": "pad",
                    "args": {"padding": [0, 1, 1, 0], "mode": "SAME"}
                }
            }
        code_template (dict): Code template generated by mapper. The format of `code_template` is as follows:
            {
                "var1": {
                    "init": [
                        "self.{var1} = nn.Conv2d(in_channels={in_channels})",
                        "self.{var1}.weight = {weight}"
                    ],
                    "construct": [
                        "opt_{var1} = self.{var1}({inputs}[, extra])"
                    ]
                },
                "var2": {
                    "init": [
                        "self.{var2} = nn.Pad(padding={padding}, mode={mode})"
                    ],
                    "construct": [
                        "opt_{var2} = self.{var2}(opt_{var1}[, extra])"
                    ]
                }
            }
        outputs (list[str]): Outputs name slot list, e.g. "opt_{var1}".
        outputs_mapping (tuple): Outputs index mapping between ir node and MindSpore operation.
    """

    # Captures the name inside the last `{...}` placeholder of an output slot such as "opt_{var1}".
    _OUTPUT_VAR_EXTRACTOR = re.compile(r".*\{(?P<var>.+)\}.*")

    def __init__(self, data_entity: dict, code_template: dict, outputs: List[str], outputs_mapping):
        # Per-op data used to fill the templates; `__call__` may add precursor entries to its "args".
        self.exchange_msg = data_entity
        self._code_template = code_template
        self.inputs = []
        self._outputs = outputs
        self.outputs_mapping = outputs_mapping
        self.format_args = dict()

    def _get_outputs(self):
        """
        Get outputs of the code snippet.

        Returns:
            list[tuple[str, str]], pairs of (rendered output name, output type) for each output slot.

        Raises:
            ValueError: If an output slot contains no `{...}` placeholder.
        """
        scope = ExchangeMessageKeywords.VariableScope.value
        # Map each op key to its generated variable name so slots like "opt_{var1}" can be rendered.
        variables = {
            k: self.exchange_msg[k][scope.VARIABLE_NAME.value]
            for k in self.exchange_msg if k != ExchangeMessageKeywords.METADATA.value
        }
        outputs = []
        for o in self._outputs:
            var_def = self._OUTPUT_VAR_EXTRACTOR.match(o)
            if not var_def:
                raise ValueError(f"Output variable name {o} is illegal.")
            op_key = var_def.group("var")
            outputs.append((o.format(**variables), self.exchange_msg[op_key][scope.OUTPUT_TYPE.value]))
        return outputs

    def get_outputs_by_idx(self, idx, inner_idx=-1):
        """
        Get the output expression at position `idx`.

        Args:
            idx (int): Index into the outputs slot list.
            inner_idx (int): Element index appended when the output type is array. Default: -1.

        Returns:
            str, the output variable name, suffixed with `[inner_idx]` when its output type is array.
        """
        opt, opt_type = self._get_outputs()[idx]
        if opt_type == ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value:
            return f"{opt}[{inner_idx}]"
        return opt

    @staticmethod
    def create_parameter(weight_shape, weight_dtype):
        """
        Create a parameter code line.

        Args:
            weight_shape: Shape, rendered with `str()` into the `np.random.uniform` call.
            weight_dtype (str): NumPy dtype attribute name, e.g. "float32".

        Returns:
            str, a `Parameter(Tensor(...), name=None)` expression initialised with uniform random values.
        """
        return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
               f"name=None)"

    def __call__(self) -> Tuple[List[str], List[str]]:
        """
        Render every op template with the exchange message.

        Ops are rendered in template order. For an op whose data has no explicit ``inputs``, the previously
        rendered op's variable name is added to this op's ``args`` under the previous op's key. A template
        such as ``"opt_{var2} = self.{var2}(opt_{var1})"`` can then resolve ``{var1}``.

        Returns:
            tuple[list[str], list[str]], init statements and construct statements.
        """
        scope = ExchangeMessageKeywords.VariableScope.value
        init_stats, call_stats = [], []
        precursor_key, precursor_name = None, None
        for op_var, template in self._code_template.items():
            op_data = self.exchange_msg[op_var]
            # The first op has no precursor. Injecting a `None` key into `args` would make
            # `str.format(**kwargs)` raise TypeError, so only wire a precursor that really exists.
            if scope.INPUTS.value not in op_data and precursor_key is not None:
                op_data[scope.ARGS.value][precursor_key] = precursor_name
            for tpl in template[TemplateKeywords.INIT.value]:
                init_stats.append(self._rewrite(op_var, op_data, tpl))
            for tpl in template[TemplateKeywords.CONSTRUCT.value]:
                call_stats.append(self._rewrite(op_var, op_data, tpl))
            precursor_key, precursor_name = op_var, op_data.get(scope.VARIABLE_NAME.value)
        return init_stats, call_stats

    def register_parameter(self, var, line):
        """
        Append a new line to the init template of op `var`.

        Args:
            var (str): Op key in the code template.
            line (str): Template line to append.
        """
        self._code_template[var][TemplateKeywords.INIT.value].append(line)

    @staticmethod
    def _merge_grouped_inputs(inputs, groups):
        """
        Collapse runs of inputs into grouped entries.

        Args:
            inputs (list[str]): Input variable names.
            groups (list[tuple]): Entries of the form ``(start, end[, pattern])``. Inputs ``start..end``
                (inclusive) are joined with ", ". If ``pattern`` is given, the joined text is substituted
                into its ``{%}`` slot.

        Returns:
            list[str], the inputs with each group collapsed into a single string.
        """
        merged = []
        input_index = 0
        group_id = 0
        while input_index < len(inputs):
            # A group starts when the current index is one of the current group's bounds.
            if group_id < len(groups) and input_index in groups[group_id]:
                group = groups[group_id]
                code_pattern = group[2] if len(group) > 2 else "{%}"
                joined = ", ".join(inputs[group[0]:group[1] + 1])
                merged.append(code_pattern.format(**{"%": joined}))
                input_index = group[1] + 1
                group_id += 1
                continue
            merged.append(inputs[input_index])
            input_index += 1
        return merged

    @staticmethod
    def _rewrite(var, data, template: str) -> str:
        """
        Backfill data into code template.

        Args:
            var (str): Current operation variable name.
            data (dict): Data to be written.
            template (str): Code template.

        Returns:
            str, single code line.
        """
        scope = ExchangeMessageKeywords.VariableScope.value
        inputs_key = scope.INPUTS.value
        rewrite_data = {var: data[scope.VARIABLE_NAME.value]}
        if inputs_key in data:
            group_inputs = scope.GROUP_INPUTS.value
            if group_inputs in data:
                merged_inputs = Fragment._merge_grouped_inputs(data[inputs_key], data[group_inputs])
                if len(merged_inputs) != 1:
                    # Multiple input slots are exposed as `{inputs_0}`, `{inputs_1}`, ...
                    rewrite_data.update({f"{inputs_key}_{idx}": item for idx, item in enumerate(merged_inputs)})
                else:
                    rewrite_data[inputs_key] = merged_inputs[0]
            else:
                rewrite_data[inputs_key] = ", ".join(data[inputs_key])
        trainable_params = scope.TRAINABLE_PARAMS.value
        if trainable_params in data:
            rewrite_data.update(data[trainable_params])
        declared = scope.PARAMETERS_DECLARED.value
        if declared in data:
            # Declared parameters are referenced in templates as `{<var>/<slot>}`.
            rewrite_data.update({f"{var}/{slot}": value for slot, value in data[declared].items()})
        rewrite_data.update(data[scope.ARGS.value])
        format_kwargs = {k: str(v) for k, v in rewrite_data.items()}
        # Formatting twice resolves placeholders that survive the first pass: escaped `{{...}}` in the
        # template, or `{...}` slots contained in substituted values.
        # NOTE(review): values containing literal braces will be re-parsed by the second pass.
        return template.format(**format_kwargs).format(**format_kwargs)
